{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ce9295b2-182b-490b-8325-83a67c4a001d",
   "metadata": {},
   "source": [
    "# Chapter 4: Implementing a GPT model from Scratch To Generate Text \n",
    "\n",
    "## (Notes are in progress ...)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7da97ed-e02f-4d7f-b68e-a0eba3716e02",
   "metadata": {},
   "source": [
    "- In this chapter, we implement the architecture of a GPT-like LLM; in the next chapter, we will train this LLM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53fe99ab-0bcf-4778-a6b5-6db81fb826ef",
   "metadata": {},
   "source": [
    "## 4.1 Coding the decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5ed66875-1f24-445d-add6-006aae3c5707",
   "metadata": {},
   "outputs": [],
   "source": [
    "GPT_CONFIG = {\n",
    "    \"vocab_size\": 50257,  # Vocabulary size\n",
    "    \"ctx_len\": 1024,      # Context length\n",
    "    \"emb_dim\": 768,       # Embedding dimension\n",
    "    \"n_heads\": 12,        # Number of attention heads\n",
    "    \"n_layers\": 12,       # Number of layers\n",
    "    \"drop_rate\": 0.1,     # Dropout rate\n",
    "    \"qkv_bias\": True      # Query-Key-Value bias\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "619c2eed-f8ea-4ff5-92c3-feda0f29b227",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class DummyGPTModel(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n",
    "        self.pos_emb = nn.Embedding(cfg[\"ctx_len\"], cfg[\"emb_dim\"])\n",
    "        self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n",
    "        \n",
    "        # Use a placeholder for TransformerBlock\n",
    "        self.trf_blocks = nn.Sequential(\n",
    "            *[DummyTransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
    "        \n",
    "        # Use a placeholder for LayerNorm\n",
    "        self.final_norm = DummyLayerNorm(cfg[\"emb_dim\"])\n",
    "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False)\n",
    "\n",
    "    def forward(self, in_idx):\n",
    "        batch_size, seq_len = in_idx.shape\n",
    "        tok_embeds = self.tok_emb(in_idx)\n",
    "        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
    "        x = tok_embeds + pos_embeds\n",
    "        x = self.drop_emb(x)\n",
    "        x = self.trf_blocks(x)\n",
    "        x = self.final_norm(x)\n",
    "        logits = self.out_head(x)\n",
    "        return logits\n",
    "\n",
    "\n",
    "class DummyTransformerBlock(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        # A simple placeholder\n",
    "\n",
    "    def forward(self, x):\n",
    "        # This block does nothing and just returns its input.\n",
    "        return x\n",
    "\n",
    "\n",
    "class DummyLayerNorm(nn.Module):\n",
    "    def __init__(self, normalized_shape, eps=1e-5):\n",
    "        super().__init__()\n",
    "        # The parameters here are just to mimic the LayerNorm interface.\n",
    "\n",
    "    def forward(self, x):\n",
    "        # This layer does nothing and just returns its input.\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "794b6b6c-d36f-411e-a7db-8ac566a87fee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 6109,  3626,  6100,   345,  2651,    13],\n",
       "        [ 6109,  1110,  6622,   257, 11483,    13]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tiktoken\n",
    "import torch\n",
    "\n",
    "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "\n",
    "batch = []\n",
    "\n",
    "txt1 = \"Every effort moves you forward.\"\n",
    "txt2 = \"Every day holds a lesson.\"\n",
    "\n",
    "batch.append(torch.tensor(tokenizer.encode(txt1)))\n",
    "batch.append(torch.tensor(tokenizer.encode(txt2)))\n",
    "batch = torch.stack(batch, dim=0)\n",
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "009238cd-0160-4834-979c-309710986bb0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output shape: torch.Size([2, 6, 50257])\n",
      "tensor([[[-1.2034,  0.3201, -0.7130,  ..., -1.5548, -0.2390, -0.4667],\n",
      "         [-0.1192,  0.4539, -0.4432,  ...,  0.2392,  1.3469,  1.2430],\n",
      "         [ 0.5307,  1.6720, -0.4695,  ...,  1.1966,  0.0111,  0.5835],\n",
      "         [ 0.0139,  1.6755, -0.3388,  ...,  1.1586, -0.0435, -1.0400],\n",
      "         [ 0.0106, -1.6711,  0.7797,  ...,  0.3561, -0.0867, -0.5452],\n",
      "         [ 0.1821,  1.1189,  0.1641,  ...,  1.9012,  1.2240,  0.8853]],\n",
      "\n",
      "        [[-1.0341,  0.2765, -1.1252,  ..., -0.8381,  0.0773,  0.1147],\n",
      "         [-0.2632,  0.5427, -0.2828,  ...,  0.1357,  0.3707,  1.3615],\n",
      "         [ 0.9695,  1.2466, -0.3515,  ..., -0.0171, -0.3478,  0.2616],\n",
      "         [-0.0237, -0.7329,  0.3184,  ...,  1.5946, -0.1334, -0.2981],\n",
      "         [-0.1876, -0.7909,  0.8811,  ...,  1.1121, -0.3781, -1.4438],\n",
      "         [ 0.0405,  1.2000,  0.0702,  ...,  1.4740,  1.1567,  1.2077]]],\n",
      "       grad_fn=<UnsafeViewBackward0>)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "model = DummyGPTModel(GPT_CONFIG)\n",
    "\n",
    "out = model(batch)\n",
    "print(\"Output shape:\", out.shape)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62598daa-f819-40da-95ca-899988b6f8da",
   "metadata": {},
   "source": [
    "## 4.2 Normalizing activations with LayerNorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3333a305-aa3d-460a-bcce-b80662d464d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LayerNorm(nn.Module):\n",
    "    def __init__(self, emb_dim):\n",
    "        super().__init__()\n",
    "        self.eps = 1e-5\n",
    "        self.scale = nn.Parameter(torch.ones(emb_dim))\n",
    "        self.shift = nn.Parameter(torch.zeros(emb_dim))\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean = x.mean(-1, keepdim=True)\n",
    "        var = x.var(-1, keepdim=True, unbiased=False)\n",
    "        norm_x = (x - mean) / torch.sqrt(var + self.eps)\n",
    "        return self.scale * norm_x + self.shift"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd9d772b-c833-4a5c-9d58-9b208d2a0b68",
   "metadata": {},
   "source": [
    "## 4.3 Adding GeLU activation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9275c879-b148-4579-a107-86827ca14d4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GELU(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return 0.5 * x * (1 + torch.tanh(torch.sqrt(torch.tensor(2 / torch.pi)) *\n",
    "                                          (x + 0.044715 * x ** 3)))\n",
    "\n",
    "\n",
    "class FeedForward(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(cfg[\"emb_dim\"], 4 * cfg[\"emb_dim\"]),\n",
    "            GELU(),\n",
    "            nn.Linear(4 * cfg[\"emb_dim\"], cfg[\"emb_dim\"]),\n",
    "            nn.Dropout(cfg[\"drop_rate\"])\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ffcb905-53c7-4886-87d2-4464c5fecf89",
   "metadata": {},
   "source": [
    "## 4.4 Understanding shortcut connections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "05473938-799c-49fd-86d4-8ed65f94fee6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.1785],\n",
       "        [-0.0278],\n",
       "        [-0.5737],\n",
       "        [-1.5400],\n",
       "        [ 0.1513]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class ExampleWithShortcut(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(10, 10)\n",
    "        self.fc2 = nn.Linear(10, 10)\n",
    "        self.fc3 = nn.Linear(10, 1)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        identity = x\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.relu(self.fc2(x)) + identity # Shortcut connection\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "torch.manual_seed(123)\n",
    "ex_short = ExampleWithShortcut()\n",
    "inputs = torch.randn(5, 10)\n",
    "ex_short(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cae578ca-e564-42cf-8635-a2267047cdff",
   "metadata": {},
   "source": [
    "## 4.5 Connecting attention and linear layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0e1e8176-e5e3-4152-b1aa-0bbd7891dfd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from previous_chapters import MultiHeadAttention\n",
    "\n",
    "\n",
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.att = MultiHeadAttention(\n",
    "            d_in=cfg[\"emb_dim\"],\n",
    "            d_out=cfg[\"emb_dim\"],\n",
    "            block_size=cfg[\"ctx_len\"],\n",
    "            num_heads=cfg[\"n_heads\"], \n",
    "            dropout=cfg[\"drop_rate\"],\n",
    "            qkv_bias=cfg[\"qkv_bias\"])\n",
    "        self.ff = FeedForward(cfg)\n",
    "        self.norm1 = LayerNorm(cfg[\"emb_dim\"])\n",
    "        self.norm2 = LayerNorm(cfg[\"emb_dim\"])\n",
    "        self.drop_resid = nn.Dropout(cfg[\"drop_rate\"])\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.drop_resid(self.att(self.norm1(x)))\n",
    "        x = x + self.drop_resid(self.ff(self.norm2(x)))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c61de39c-d03c-4a32-8b57-f49ac3834857",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPTModel(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n",
    "        self.pos_emb = nn.Embedding(cfg[\"ctx_len\"], cfg[\"emb_dim\"])\n",
    "        \n",
    "        # Use a placeholder for TransformerBlock\n",
    "        self.trf_blocks = nn.Sequential(\n",
    "            *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
    "        \n",
    "        # Use a placeholder for LayerNorm\n",
    "        self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n",
    "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False)\n",
    "\n",
    "    def forward(self, in_idx):\n",
    "        batch_size, seq_len = in_idx.shape\n",
    "        tok_embeds = self.tok_emb(in_idx)\n",
    "        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
    "        x = tok_embeds + pos_embeds\n",
    "        x = self.trf_blocks(x)\n",
    "        x = self.final_norm(x)\n",
    "        logits = self.out_head(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "252b78c2-4404-483b-84fe-a412e55c16fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output shape: torch.Size([2, 6, 50257])\n",
      "tensor([[[-0.7971, -0.6232, -0.1815,  ...,  0.1020, -0.0916,  0.1885],\n",
      "         [ 0.5491, -0.5220,  0.7559,  ..., -0.3137, -0.8780,  0.2182],\n",
      "         [ 0.3107,  0.0346, -0.4637,  ..., -0.3700, -0.4346, -0.0747],\n",
      "         [ 0.5681,  0.3940,  0.5397,  ..., -0.1027,  0.5461,  0.4834],\n",
      "         [-0.2948, -0.1605, -0.5878,  ...,  0.0054, -0.0207, -0.1100],\n",
      "         [-0.3096, -0.7744, -0.0254,  ...,  0.7480,  0.3515,  0.3208]],\n",
      "\n",
      "        [[-0.6910, -0.3758, -0.1458,  ..., -0.1824, -0.5231,  0.0873],\n",
      "         [-0.2562, -0.4204,  1.5507,  ..., -0.7057, -0.3989,  0.0084],\n",
      "         [-0.4263, -0.2257, -0.2074,  ..., -0.2160, -1.1648,  0.4744],\n",
      "         [-0.0245,  1.3792,  0.2234,  ..., -0.7153, -0.7858, -0.3762],\n",
      "         [-0.4696, -0.4584, -0.4812,  ...,  0.5044, -0.8911,  0.1549],\n",
      "         [-0.7727, -0.6125, -0.3203,  ...,  1.0753, -0.0878,  0.2805]]],\n",
      "       grad_fn=<UnsafeViewBackward0>)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "model = GPTModel(GPT_CONFIG)\n",
    "\n",
    "out = model(batch)\n",
    "print(\"Output shape:\", out.shape)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "84fb8be4-9d3b-402b-b3da-86b663aac33a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of parameters: 163,037,184\n",
      "Number of trainable parameters considering weight tying: 124,439,808\n"
     ]
    }
   ],
   "source": [
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"Total number of parameters: {total_params:,}\")\n",
    "\n",
    "total_params_gpt2 =  total_params - sum(p.numel() for p in model.tok_emb.parameters())\n",
    "print(f\"Number of trainable parameters considering weight tying: {total_params_gpt2:,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5131a752-fab8-4d70-a600-e29870b33528",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total size of the model: 621.94 MB\n"
     ]
    }
   ],
   "source": [
    "# Calculate the total size in bytes (assuming float32, 4 bytes per parameter)\n",
    "total_size_bytes = total_params * 4\n",
    "\n",
    "# Convert to megabytes\n",
    "total_size_mb = total_size_bytes / (1024 * 1024)\n",
    "\n",
    "print(f\"Total size of the model: {total_size_mb:.2f} MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da5d9bc0-95ab-45d4-9378-417628d86e35",
   "metadata": {},
   "source": [
    "## 4.6 Implementing the forward pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07700ec8-32e8-4775-9c13-5c43671d6728",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
